import json
import numpy as np
import pickle
import pandas as pd
import igraph
import networkx as nx
import matplotlib.pyplot as plt



def read_dataset(file_name):
    """Convert every schema in ``<file_name>.json`` into an igraph graph.

    The JSON document must contain a ``'schemas'`` list; each schema must
    provide ``'steps'`` (events) and ``'order'`` (temporal links).  For each
    schema the events, links and ontology are printed, an event graph is
    built, and the resulting list of graphs is pickled to
    ``<file_name>_igraphs.pkl``.

    Args:
        file_name: Path of the dataset without the ``.json`` extension.

    Returns:
        None.  The graphs are written to disk; only the graph built from
        the "all nodes" mapping is kept, the second graph returned by
        ``_construct_event_graph`` is discarded.
    """
    # Use a context manager so the JSON handle is closed deterministically.
    with open(file_name + '.json') as f:
        data = json.load(f)
    schemas = data['schemas']

    all_schemas_event_graphs = []
    # The ontology does not depend on the schema, so it is read from disk at
    # most once.  It is loaded lazily (on the first schema) so that an empty
    # schema list still never touches the ontology file.
    events_ontology = None
    for schema in schemas:
        curr_events = collect_events(schema['steps'])
        curr_event_links = collect_event_links(schema['order'])
        print(curr_events)
        print(curr_event_links)
        if events_ontology is None:
            events_ontology = _load_event_ontology()
        print(events_ontology)
        event_name_to_id = _construct_event_name_to_id_dict(
            curr_events, events_ontology)
        event_graphs_all, _ = _construct_event_graph(
            curr_event_links, event_name_to_id, curr_events)
        all_schemas_event_graphs.append(event_graphs_all)

    print(all_schemas_event_graphs)
    # Report vertex and edge counts of every generated graph.
    for event_graph in all_schemas_event_graphs:
        print(len(event_graph.vs))
        print(len(event_graph.es))

    with open(file_name + '_igraphs.pkl', 'wb') as handle:
        pickle.dump(all_schemas_event_graphs, handle)







def collect_events(events):
    """Extract ``(name, type)`` pairs from a schema's step list.

    Args:
        events: Iterable of step dicts, each with an ``'@id'`` and an
            ``'@type'`` entry.

    Returns:
        A list of ``(event_name, event_type)`` tuples in input order, where
        ``event_name`` is the ``'@id'`` value and ``event_type`` is the part
        of ``'@type'`` after its last ``'/'`` (the whole string when it
        contains no ``'/'``).

    Raises:
        KeyError: If a step lacks ``'@id'`` or ``'@type'``.
    """
    return [
        (event['@id'], event['@type'].split('/')[-1])
        for event in events
    ]

def collect_event_links(events_orders):
    """Expand a schema's ordering constraints into pairwise links.

    Args:
        events_orders: Iterable of dicts with ``'before'`` and ``'after'``
            entries.  Each entry is either a single event or a list of
            events.

    Returns:
        A list of ``(start_event, end_event)`` tuples: for every ordering
        entry, one tuple per combination of a ``'before'`` event with an
        ``'after'`` event, in input order.

    Raises:
        KeyError: If an entry lacks ``'before'`` or ``'after'``.
    """
    all_event_links = []
    for order in events_orders:
        start_events = order['before']
        end_events = order['after']
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are expanded too instead of being wrapped as one event.
        if not isinstance(start_events, list):
            start_events = [start_events]
        if not isinstance(end_events, list):
            end_events = [end_events]
        all_event_links.extend(
            (start_event, end_event)
            for start_event in start_events
            for end_event in end_events
        )
    return all_event_links



def _load_event_ontology():
    """Load the event-type ontology and reserve ids for START and END.

    Reads ``./data/kairos_ontology.pkl`` (relative to the current working
    directory), takes the first element of the pickled sequence and its
    ``'event_types'`` mapping, and shifts every id by 2 so that ids 0 and 1
    are free for the artificial ``"START"`` and ``"END"`` types.

    Returns:
        A new dict mapping event-type name to integer id, containing
        ``"START": 0`` and ``"END": 1`` plus every ontology entry with its
        id increased by 2.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # this path must only ever point at a trusted, locally produced file.
    with open("./data/kairos_ontology.pkl", "rb") as handle:
        saved_dict = pickle.load(handle)[0]
    event_types_ontology_new = {"START": 0, "END": 1}
    event_types_ontology_new.update(
        (key, val + 2) for key, val in saved_dict['event_types'].items()
    )
    return event_types_ontology_new

def _construct_event_name_to_id_dict(all_graph_events, events_ontology):
    """Map each event name to the ontology id of its event type.

    Args:
        all_graph_events: Sequence of ``(event_name, event_type)`` pairs.
        events_ontology: Dict from event-type name to integer id.

    Returns:
        Dict from event name to type id.  When a name occurs more than
        once, the last occurrence wins.

    Raises:
        KeyError: If an event type is missing from ``events_ontology``.
    """
    return {
        event[0]: events_ontology[event[1]]
        for event in all_graph_events
    }

def _get_graph_node_num(graph_list):
    """Assign 1-based ids to the nodes that appear in a list of links.

    Args:
        graph_list: Sequence of links; element ``[0]`` of each link is the
            start node name and element ``[1]`` the end node name.

    Returns:
        A pair of dicts ``(all_nodes_name_to_id, no_iso_name_to_id)``.
        Both map node name to an id starting at 1, numbered in the order
        the nodes were first seen.  The second dict is computed after
        removing isolated (degree-0) nodes from a copy of the graph.
    """
    graph_all_nodes = nx.DiGraph()
    for link in graph_list:
        graph_all_nodes.add_edge(link[0], link[1])

    # NOTE(review): nodes only ever enter the graph as edge endpoints here,
    # so each has degree >= 1 and nx.isolates yields nothing; both mappings
    # therefore come out the same.  Events without any link are not seen by
    # this function at all - confirm whether that is intended.
    graph_without_iso = graph_all_nodes.copy()
    # Materialise the isolates first: the graph must not change size while
    # the generator returned by nx.isolates is being consumed.
    isolated_events = list(nx.isolates(graph_without_iso))
    graph_without_iso.remove_nodes_from(isolated_events)

    # Ids start at 1 because id 0 is used for the START node by the caller.
    graph_all_nodes_name_to_id_dict = {
        name: node_id
        for node_id, name in enumerate(graph_all_nodes.nodes, start=1)
    }
    graph_without_iso_name_to_id_dict = {
        name: node_id
        for node_id, name in enumerate(graph_without_iso.nodes, start=1)
    }
    return graph_all_nodes_name_to_id_dict, graph_without_iso_name_to_id_dict


def _get_nodes_for_start_end(graph, node_num):
    """Find the vertices that have no incoming or no outgoing edges.

    Only vertex ids ``1 .. node_num`` (inclusive) are inspected; vertex 0
    and any vertex after ``node_num`` are skipped.

    Args:
        graph: Graph whose ``vs[i]`` vertices expose ``indegree()`` and
            ``outdegree()``.
        node_num: Highest vertex id to inspect.

    Returns:
        A pair of sets ``(ids_with_indegree_0, ids_with_outdegree_0)``.
    """
    candidates = range(1, node_num + 1)
    without_incoming = {
        vertex_id for vertex_id in candidates
        if graph.vs[vertex_id].indegree() == 0
    }
    without_outgoing = {
        vertex_id for vertex_id in candidates
        if graph.vs[vertex_id].outdegree() == 0
    }
    return without_incoming, without_outgoing

def _reorder_graph(graph):
    """Rebuild ``graph`` with its vertices renumbered in topological order.

    Args:
        graph: Directed igraph graph whose vertices carry a ``'type'``
            attribute.

    Returns:
        A new directed graph with the same number of vertices.  The vertex
        at position ``i`` of the topological ordering becomes vertex ``i``
        and keeps its ``'type'``; every edge is re-added between the
        renumbered endpoints, in the original edge order.  Only the
        ``'type'`` vertex attribute is carried over.
    """
    topo_order = graph.topological_sorting(mode='out')

    reordered = igraph.Graph(directed=True)
    reordered.add_vertices(len(list(graph.vs)))

    # old vertex id -> new vertex id
    old_to_new = {}
    for new_id, old_id in enumerate(topo_order):
        reordered.vs[new_id]['type'] = graph.vs[old_id]['type']
        old_to_new[old_id] = new_id

    for edge in list(graph.es):
        source_id = old_to_new[edge.source]
        target_id = old_to_new[edge.target]
        reordered.add_edge(source_id, target_id)
    return reordered



def _construct_and_reorder(graph_name_to_id, event_dict, event_links):
    """Build a typed event graph with START/END nodes, topologically ordered.

    Vertex 0 is the START node (type 0) and vertex ``n + 1`` the END node
    (type 1), where ``n`` is the number of named nodes.  Each link whose
    two endpoints are both in ``graph_name_to_id`` becomes an edge, and its
    endpoints receive their type from ``event_dict``.  START is then linked
    to every node without incoming edges and every node without outgoing
    edges is linked to END, before the graph is renumbered by
    ``_reorder_graph``.

    Args:
        graph_name_to_id: Dict from node name to vertex id in ``1 .. n``.
        event_dict: Dict from node name to event-type id.
        event_links: Sequence of links; element ``[0]`` is the start node
            name and element ``[1]`` the end node name.

    Returns:
        The reordered directed igraph graph.
    """
    n = len(graph_name_to_id)
    start_vertex = 0
    end_vertex = n + 1

    g = igraph.Graph(directed=True)
    g.add_vertices(n + 2)
    g.vs[start_vertex]['type'] = 0
    g.vs[end_vertex]['type'] = 1

    for link in event_links:
        # Links touching a node outside the mapping are ignored.
        if not (link[0] in graph_name_to_id and link[1] in graph_name_to_id):
            continue
        source_name, target_name = link[0], link[1]
        source_id = graph_name_to_id[source_name]
        source_type = event_dict[source_name]
        target_id = graph_name_to_id[target_name]
        target_type = event_dict[target_name]
        g.vs[source_id]['type'] = source_type
        g.vs[target_id]['type'] = target_type
        g.add_edge(source_id, target_id)

    # Degrees are sampled once, before any START/END edge is added, so the
    # artificial edges do not influence which nodes count as roots/leaves.
    roots, leaves = _get_nodes_for_start_end(g, len(graph_name_to_id))
    for vertex_id in range(1, n + 1):
        if vertex_id in roots:
            g.add_edge(start_vertex, vertex_id)
        if vertex_id in leaves:
            g.add_edge(vertex_id, end_vertex)

    return _reorder_graph(g)




def _construct_event_graph(all_graph_event_links, event_dict, all_graph_events):
    """Build the event graph twice: from all nodes and from non-isolated ones.

    Args:
        all_graph_event_links: Sequence of ``(start_name, end_name)`` links.
        event_dict: Dict from event name to event-type id.
        all_graph_events: Accepted for the caller's convenience; not used
            by this function.

    Returns:
        A pair ``(g_all, g_no_iso)`` of reordered igraph graphs built from
        the two name-to-id mappings produced by ``_get_graph_node_num``.
        Prints ``"not equal"`` when the first graph has more vertices than
        the second.
    """
    links = all_graph_event_links
    ids_all, ids_no_iso = _get_graph_node_num(links)

    g_all = _construct_and_reorder(ids_all, event_dict, links)
    g_no_iso = _construct_and_reorder(ids_no_iso, event_dict, links)

    vertex_count_all = len(list(g_all.vs))
    vertex_count_no_iso = len(list(g_no_iso.vs))
    if vertex_count_all > vertex_count_no_iso:
        print("not equal")
    return g_all, g_no_iso





# Dataset path without extension; read_dataset appends '.json' for the input
# and '_igraphs.pkl' for the output.
file_name = "./data/RESIN_schema/resin-schemalib"

# Only convert the dataset when executed as a script, so that importing this
# module (e.g. to reuse the helpers) does not trigger file I/O.
if __name__ == "__main__":
    read_dataset(file_name)















